{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset-bpe.json') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = data['train_X']\n",
    "train_Y = data['train_Y']\n",
    "test_X = data['test_X']\n",
    "test_Y = data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS = 2\n",
    "GO = 1\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_Y = [i + [2] for i in train_Y]\n",
    "test_Y = [i + [2] for i in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    padding = tf.tile([[[0.0]]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1], tf.shape(x)[2]], 0))\n",
    "    return tf.concat([x, padding], 1)\n",
    "\n",
    "class Translator:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, learning_rate,\n",
    "                beam_width = 5):\n",
    "        \n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.GRUCell(size_layer, reuse=reuse)\n",
    "        \n",
    "        def attention(encoder_out, seq_len, reuse=False):\n",
    "            attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_units = size_layer, \n",
    "                                                                    memory = encoder_out,\n",
    "                                                                    memory_sequence_length = seq_len)\n",
    "            return tf.contrib.seq2seq.AttentionWrapper(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([cells(reuse) for _ in range(num_layers)]), \n",
    "                attention_mechanism = attention_mechanism,\n",
    "                attention_layer_size = size_layer)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))\n",
    "        \n",
    "        encoder_out, encoder_state = tf.nn.dynamic_rnn(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]), \n",
    "            inputs = tf.nn.embedding_lookup(embeddings, self.X),\n",
    "            sequence_length = self.X_seq_len,\n",
    "            dtype = tf.float32)\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        dense = tf.layers.Dense(vocab_size)\n",
    "        \n",
    "        with tf.variable_scope('decode'):\n",
    "            decoder_cells = attention(encoder_out, self.X_seq_len)\n",
    "            states = decoder_cells.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state)\n",
    "\n",
    "            training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                    inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                    sequence_length = self.Y_seq_len,\n",
    "                    time_major = False)\n",
    "            training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                    cell = decoder_cells,\n",
    "                    helper = training_helper,\n",
    "                    initial_state = states,\n",
    "                    output_layer = dense)\n",
    "            training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                    decoder = training_decoder,\n",
    "                    impute_finished = True,\n",
    "                    maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "            self.training_logits = training_decoder_output.rnn_output\n",
    "        \n",
    "        with tf.variable_scope('decode', reuse=True):\n",
    "            encoder_out_tiled = tf.contrib.seq2seq.tile_batch(encoder_out, beam_width)\n",
    "            encoder_state_tiled = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)\n",
    "            X_seq_len_tiled = tf.contrib.seq2seq.tile_batch(self.X_seq_len, beam_width)\n",
    "            decoder_cell = attention(encoder_out_tiled, X_seq_len_tiled, reuse=True)\n",
    "            states = decoder_cell.zero_state(batch_size * beam_width, tf.float32).clone(\n",
    "                    cell_state = encoder_state_tiled)\n",
    "            predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n",
    "                cell = decoder_cell,\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS,\n",
    "                initial_state = states,\n",
    "                beam_width = beam_width,\n",
    "                output_layer = dense,\n",
    "                length_penalty_weight = 0.0)\n",
    "            predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = False,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "            self.fast_result = predicting_decoder_output.predicted_ids[:, :, 0]\n",
    "        \n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 512\n",
    "num_layers = 2\n",
    "embedded_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 128\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "WARNING:tensorflow:From <ipython-input-7-f60202f051eb>:12: GRUCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.GRUCell, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-7-f60202f051eb>:33: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-7-f60202f051eb>:36: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:559: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:565: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:575: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn.py:244: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/contrib/seq2seq/python/ops/beam_search_decoder.py:971: to_int64 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.cast` instead.\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(size_layer, num_layers, embedded_size, learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[17978, 17978, 17978,  8641, 24680, 24680, 24680, 24680, 24680,\n",
       "         24680, 24680, 24680, 24680, 24680, 24680, 15938,  5340, 12951,\n",
       "         12951, 12951, 12951, 12951,   837,   837,   837,   837,   837,\n",
       "           837, 25762, 25762, 25762, 25762, 31802,  8235,  8235, 16601,\n",
       "         30709, 30709, 30709, 30709, 30709, 21454, 21454, 21454, 21454,\n",
       "          5801, 27281, 12363, 12363, 12363, 12363, 12363, 12363, 12363,\n",
       "         10247, 10247, 10247, 10247,  4951,  4951,  4951,  4951,  4951,\n",
       "         28836, 28836, 28836, 21485, 23206, 23206, 23206, 23206,  7502],\n",
       "        [10552, 10552, 10552, 10552, 31545, 12915, 12915, 12915, 12362,\n",
       "         24102, 24102, 24102, 31617, 20089, 20089, 20089, 27026, 27026,\n",
       "          1299, 25489, 25489, 23610, 23610, 23610,  5026,  5026,  5026,\n",
       "          5026,  5026,  5026,  5026,  5026, 26098, 26098, 26098,  8896,\n",
       "          8896, 13024, 13024, 13024, 26098, 16926, 16926, 24488, 24488,\n",
       "         24488, 24488, 24488,  1770,  1770,  1770,  1770,  1770,  1770,\n",
       "          1770,  1770,  1770,  1770,  1770,  1770,  1770,  1770,  1770,\n",
       "          1770,  1770,  1770,  1770,  1770,  1770,  1770, 29707, 12766],\n",
       "        [11867, 13536,  2818,  2818,  2818,  2818, 30915, 30915, 10623,\n",
       "         10623, 10623, 10623, 10623, 10623, 20871, 13949, 13949, 13949,\n",
       "         13949,  9922,  9922,  9922,  9922,  9922,  9922,  9922,  4729,\n",
       "          4729,  4729,  4729,  3072,  3072,  3072,  3072,  3072,  7076,\n",
       "          7076,  7076,  7076,  7076,  7076,  7076,  7076, 12963, 12963,\n",
       "          7076, 12963, 12963, 12963, 12963, 24256, 24256, 29058, 29058,\n",
       "         29058, 29058, 29058, 29058, 29025, 18055, 18055, 18055, 18055,\n",
       "         18055, 18055,  9849, 17761, 17761, 17761, 17761, 17991, 24325],\n",
       "        [ 4424,  4424, 14165, 14084, 14084, 14084, 14084,  5052, 26143,\n",
       "         26143, 26143, 26143, 26143, 12923, 12923, 30747, 30747, 30747,\n",
       "         27182, 27182, 27182, 31878, 31878, 29210, 31878, 29210, 29210,\n",
       "         22370, 22370,  1819,  1819,  1819, 24160, 24160, 13804, 13804,\n",
       "         13804, 13804, 13804, 13804, 13804, 29416, 29416, 25621, 25621,\n",
       "         25621, 19259, 19259, 19259, 19259,  1009, 27228, 27228, 31412,\n",
       "         31412, 10249, 10249, 10249, 10249, 14590, 24263, 14590, 14590,\n",
       "         24263, 14590, 27544, 27544, 27544, 17580, 17580, 17580, 27380],\n",
       "        [27107, 27107, 27107, 25597, 25597,  1305,  1305, 21927, 21927,\n",
       "         21927, 29925, 29925, 29925, 29925, 29641, 29641, 29641, 29641,\n",
       "         26220, 26220, 26220, 26220, 26220, 26220, 26220,  2472,  2472,\n",
       "          2472, 25517, 25517, 25517, 25517, 25517,  8836,  8836, 27606,\n",
       "         27606,  4393,  4393,  2241,  2241,  2241,  2241, 12029, 12029,\n",
       "         12029, 12029, 12029, 24179, 27172, 27172, 17509, 17509, 17509,\n",
       "         17509, 17509, 17509, 23298, 23298, 22618, 22618, 22618, 22618,\n",
       "         22618, 22618,  1662,  1662, 15705, 15705, 15705, 15705, 21692],\n",
       "        [ 9172,  9172,  9172, 28629, 12377, 28629, 12377, 28629, 19336,\n",
       "         19336, 19336, 19336, 30709, 30709,  6515,  6515, 14053, 14053,\n",
       "         14053, 13684, 13684, 13684, 17921, 17921, 17921, 17921, 14163,\n",
       "         26497, 24368, 24368, 24368, 24368, 24368, 24368, 26293, 26293,\n",
       "         26293, 26293, 31962, 31962, 31962,  1096,  1096,  1096,  1096,\n",
       "         29472, 29472, 29472, 29472, 11642, 11642, 11642, 11642, 11642,\n",
       "         11642, 11642,  6985,  6985,  8361,  8361,  8361, 11283, 18507,\n",
       "         18507, 18507, 18507, 25722, 25722,  6717,  6717,  6717,  6717],\n",
       "        [ 7679,  7679,  5737,  5737, 18576, 18576, 18576, 18576, 19840,\n",
       "         19840, 19840, 19840, 19840, 19840, 19840, 19840, 19840, 18288,\n",
       "         31096, 31096, 31096, 31096, 31096, 31096, 31096, 31096, 12684,\n",
       "         12684,  4116,  4116,  2927,  2927,  2927,  2927,  2927,  2927,\n",
       "          2927, 23364, 23364,  2927,  2927,  2927,  2927,  2927,  2927,\n",
       "         23364, 23364,  2927,  2927,  2927,  2927,  2927, 24488, 24488,\n",
       "          2927, 24488, 24488, 24488, 24488,  4096,  4096, 24488, 24488,\n",
       "         20923, 20923, 20923, 20923, 29644, 29644, 29644, 29644,  4697],\n",
       "        [21891, 21891,  3325,  3325,  3325, 12500,  6856,  6856,  6856,\n",
       "          6856, 15128, 15128, 15128, 15128, 15128, 15128, 15128,  5779,\n",
       "          5779,  5779,  2414,  9630,  9630,  9630,  1271,  1271,  8146,\n",
       "          8146,  8146,  8146,   571,   571,   571,   571,   571,  1757,\n",
       "          1757,  1757,  1998,  1998,  1998,  1998,  4404,  4404, 30835,\n",
       "         30835, 30835, 30835, 31512, 31512, 31512, 31512, 22308, 22308,\n",
       "         22308, 22308, 19578, 13788, 13788, 13788, 13788, 13788, 13788,\n",
       "          6408,  6408, 20913, 23621, 23621, 23621, 23621,  7031,  7031],\n",
       "        [ 4925,  4925, 13604, 13604, 13604, 16544, 16544, 16544, 16544,\n",
       "         16544, 16544, 16544, 29554, 29554, 29554, 29554, 26381, 29554,\n",
       "         26381, 20252, 20252, 20252, 20252, 20252, 24409, 24409, 24409,\n",
       "         19330,  4349, 19330, 19330, 19330, 27655, 27655, 13303, 27655,\n",
       "         13303, 22948, 22948, 13699, 13699,  5056,  5056, 27710, 27710,\n",
       "         25981, 25981, 25981, 27080, 27080, 27080, 27080, 21865, 21865,\n",
       "         21865, 21865, 11191, 11191, 11191, 11191, 11191, 11191, 11191,\n",
       "         11191, 11191, 11191, 11191, 11191, 11191, 11191, 11191, 11191],\n",
       "        [22031, 22031, 22031, 22031, 22031, 12795, 12795, 12795, 12795,\n",
       "         12795, 23284, 23284, 20913, 20913, 20913, 20913, 20913, 20913,\n",
       "         20913, 20913, 20913, 20913, 20913, 20913, 20913, 20913, 20913,\n",
       "         20913,  3622,  3622,  3622,  3363,  3363,  3363,   360, 20913,\n",
       "         20913, 20913, 20913, 20913, 20913, 20913, 20913, 20913, 20913,\n",
       "         20913, 20913, 20913, 20913, 20913, 20913, 20913, 20913, 21165,\n",
       "         20913, 21165, 21165, 21165, 21165, 21165,  6294, 16768, 16768,\n",
       "         16365, 16365, 16365, 18402, 18402, 18402, 18402, 18402, 22472]],\n",
       "       dtype=int32), 10.372692, 0.0]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x = pad_sequences(train_X[:10], padding='post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding='post')\n",
    "\n",
    "sess.run([model.fast_result, model.cost, model.accuracy], \n",
    "         feed_dict = {model.X: batch_x, model.Y: batch_y})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.39it/s, accuracy=0.284, cost=4.39]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.82it/s, accuracy=0.333, cost=3.99]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 5.374818, training avg acc 0.204933\n",
      "epoch 1, testing avg loss 4.256302, testing avg acc 0.298886\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.365, cost=3.65]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.92it/s, accuracy=0.355, cost=3.77]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 3.885117, training avg acc 0.336039\n",
      "epoch 2, testing avg loss 3.840724, testing avg acc 0.349001\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.40it/s, accuracy=0.419, cost=3.22]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.91it/s, accuracy=0.414, cost=3.33]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 3.423263, training avg acc 0.384755\n",
      "epoch 3, testing avg loss 3.623492, testing avg acc 0.370631\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.405, cost=3.14]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.90it/s, accuracy=0.398, cost=3.16]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 3.228974, training avg acc 0.405378\n",
      "epoch 4, testing avg loss 3.629654, testing avg acc 0.372957\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.411, cost=3.1] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.91it/s, accuracy=0.392, cost=3.24]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 3.334839, training avg acc 0.389818\n",
      "epoch 5, testing avg loss 3.608245, testing avg acc 0.376056\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.442, cost=2.88]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.446, cost=3.12]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 3.148405, training avg acc 0.411932\n",
      "epoch 6, testing avg loss 3.540085, testing avg acc 0.385920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:55<00:00,  2.39it/s, accuracy=0.132, cost=7.24] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.90it/s, accuracy=0.129, cost=6.57]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 4.979289, training avg acc 0.308916\n",
      "epoch 7, testing avg loss 7.174543, testing avg acc 0.135664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:47<00:00,  2.41it/s, accuracy=0.0706, cost=8.74]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.99it/s, accuracy=0.0376, cost=8.12]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 7.963271, training avg acc 0.107955\n",
      "epoch 8, testing avg loss 8.641388, testing avg acc 0.075744\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:50<00:00,  2.40it/s, accuracy=0.126, cost=6.51] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.90it/s, accuracy=0.124, cost=6.09]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 7.152506, training avg acc 0.116101\n",
      "epoch 9, testing avg loss 6.580836, testing avg acc 0.134291\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.39it/s, accuracy=0.136, cost=6.33]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.151, cost=5.74]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 6.888044, training avg acc 0.132050\n",
      "epoch 10, testing avg loss 6.357913, testing avg acc 0.148491\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.39it/s, accuracy=0.136, cost=6.49]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.145, cost=6.08]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 7.760569, training avg acc 0.132505\n",
      "epoch 11, testing avg loss 6.507605, testing avg acc 0.149597\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.40it/s, accuracy=0.171, cost=5.64]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.172, cost=5.3] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 5.968809, training avg acc 0.163244\n",
      "epoch 12, testing avg loss 5.705050, testing avg acc 0.175293\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:54<00:00,  2.39it/s, accuracy=0.197, cost=5.13]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.183, cost=4.91]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 5.328347, training avg acc 0.188188\n",
      "epoch 13, testing avg loss 5.240874, testing avg acc 0.198084\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.225, cost=4.78]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.215, cost=4.63]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 4.905855, training avg acc 0.213974\n",
      "epoch 14, testing avg loss 4.929050, testing avg acc 0.219257\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.39it/s, accuracy=0.25, cost=4.41] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.87it/s, accuracy=0.253, cost=4.47]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 4.572335, training avg acc 0.238671\n",
      "epoch 15, testing avg loss 4.688868, testing avg acc 0.239015\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:52<00:00,  2.40it/s, accuracy=0.278, cost=4.07]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.93it/s, accuracy=0.263, cost=4.28]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 4.276408, training avg acc 0.263065\n",
      "epoch 16, testing avg loss 4.491760, testing avg acc 0.256634\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.317, cost=3.7] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.94it/s, accuracy=0.28, cost=4.13] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 4.003881, training avg acc 0.287369\n",
      "epoch 17, testing avg loss 4.342401, testing avg acc 0.274190\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:51<00:00,  2.40it/s, accuracy=0.0984, cost=8.27] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.92it/s, accuracy=0.108, cost=8.07] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 7.807547, training avg acc 0.179802\n",
      "epoch 18, testing avg loss 8.499818, testing avg acc 0.098730\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:49<00:00,  2.41it/s, accuracy=0.129, cost=7.12]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.91it/s, accuracy=0.118, cost=6.6] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 7.928953, training avg acc 0.117910\n",
      "epoch 19, testing avg loss 7.259023, testing avg acc 0.134780\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:47<00:00,  2.42it/s, accuracy=0.155, cost=6.13]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:08<00:00,  4.87it/s, accuracy=0.145, cost=5.69]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 6.658662, training avg acc 0.149770\n",
      "epoch 20, testing avg loss 6.352009, testing avg acc 0.159551\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = pad_sequences(train_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(train_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y}\n",
    "        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],\n",
    "                                    feed_dict = feed)\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    \n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(test_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,}\n",
    "        accuracy, loss = sess.run([model.accuracy,model.cost],\n",
    "                                    feed_dict = feed)\n",
    "\n",
    "        test_loss.append(loss)\n",
    "        test_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import bleu_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:37<00:00,  1.08it/s]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "for i in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    index = min(i + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "    feed = {model.X: batch_x}\n",
    "    p = sess.run(model.fast_result,feed_dict = feed)\n",
    "    result = []\n",
    "    for row in p:\n",
    "        result.append([i for i in row if i > 3])\n",
    "    results.extend(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "rights = []\n",
    "for r in test_Y:\n",
    "    rights.append([i for i in r if i > 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.003980886"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bleu_hook.compute_bleu(reference_corpus = rights,\n",
    "                       translation_corpus = results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
